#' ---
#' title: Reproduce Appendix Figure B1 (Twitter Sentiment Analysis Task, Continuous Score Prompt)
#' date: 2024-10-05
#' version: 1.0
#' ---

library(tidyverse)
library(promptr)
set.seed(42)

## Load Data -------------------

# Tweet dataset bundled with promptr. The code below relies on its `text`,
# `case`, and `expert1`-`expert3` columns (three hand-coded sentiment
# ratings per tweet, averaged later into `expert_score`).
scotus_tweets <- promptr::scotus_tweets

## Function to plot results ---------------------

#' Scatterplot of model-assigned vs. hand-coded sentiment scores
#'
#' @param scotus_tweets Data frame with numeric columns `score` (the model's
#'   score) and `expert_score` (the hand-coded score).
#' @param model Model label, used in the y-axis title and the plot title.
#' @return A ggplot object with jittered points and a linear trend line
#'   (no confidence band). The title reports the correlation (cor()'s
#'   default method) between the two scores, computed on rows where both are
#'   present and rounded to two decimals.
plot_results <- function(scotus_tweets, model){

  # Correlation of model vs. expert scores, skipping rows with a missing value
  rho <- cor(x = scotus_tweets$score,
             y = scotus_tweets$expert_score,
             use = 'pairwise.complete.obs')

  plot_title <- paste0(model, ' ', '(\U03C1 = ', round(rho, 2), ')')
  y_label <- paste0(model, ' Sentiment Score')

  base_plot <- ggplot(scotus_tweets, aes(x = expert_score, y = score))

  # Points are drawn first so the trend line sits on top of them
  base_plot +
    geom_jitter(width = 0.1, alpha = 0.7) +
    labs(x = 'Hand-Coded Sentiment Score', y = y_label, title = plot_title) +
    theme_bw() +
    geom_smooth(method = 'lm', se = FALSE, color = 'gray')
}

# Toggle: TRUE replays the GPT-4 completions cached in
# 'appendix-B-tweets.RData'; FALSE builds the prompts and submits them to
# the API through promptr (fresh results are saved to a separate file).
from_file <- TRUE

if (!from_file) {

  # add Le Mens & Gallego instructions to dataframe: one instruction text per
  # Supreme Court case, each asking for a 0-100 sentiment score
  scotus_tweets <- scotus_tweets |>
    mutate(expert_score = (expert1 + expert2 + expert3) / 3,
           instructions = case_when(case == 'masterpiece' ~ "Read this tweet posted the day after the US Supreme Court ruled in favor of a baker who refused to bake a wedding cake for a same-sex couple. What is the sentiment of this tweet? Provide your response as a score between 0 and 100 where 0 means ‘Extremely Negative’ and 100 means ‘Extremely Positive’. Respond only with this number.",
                                    case == 'mazars' ~ "Read this tweet posted the day after the US Supreme Court ruled that sitting presidents are not immune to state criminal subpoenas, and that President Trump was obliged to disclose his tax returns to the Manhattan District Attorney. What is the sentiment of this tweet? Provide your response as a score between 0 and 100 where 0 means ‘Extremely Negative’ and 100 means ‘Extremely Positive’. Respond only with this number."))

  # case_when() returns NA for any `case` value not listed above; stop here
  # rather than send prompts with missing instructions to the API.
  missing_instructions <- is.na(scotus_tweets$instructions)
  if (any(missing_instructions)) {
    stop("No instructions defined for case(s): ",
         paste(unique(scotus_tweets$case[missing_instructions]),
               collapse = ", "),
         call. = FALSE)
  }

  # create a list of formatted prompts, one per tweet
  prompts <- Map(f = format_chat,
                 text = scotus_tweets$text,
                 instructions = scotus_tweets$instructions)

  # submit the prompts to API
  out <- complete_chat(prompts,
                       model = 'gpt-4-turbo-preview',
                       parallel = TRUE)

  # Deliberately a different file from the one read in the else-branch, so
  # a fresh run never overwrites the cached completions.
  save(out, file = 'appendix-B-tweets-updated.RData')
} else {
  # load() returns the names of the objects it restored. Check that the
  # cache actually supplied `out`, which the scoring step below requires,
  # instead of failing later or silently reusing a stale `out`.
  loaded_objects <- load(file = 'appendix-B-tweets.RData')
  if (!'out' %in% loaded_objects) {
    stop("'appendix-B-tweets.RData' does not contain an object named 'out'.",
         call. = FALSE)
  }
}

# score is the probability-weighted output from GPT-4: for each tweet, the
# weighted mean of the numeric candidate tokens in the corresponding element
# of `out`, weighted by their `probability`. Each element of `out` is
# assumed to be a data frame with `token` and `probability` columns, in the
# same order as the rows of scotus_tweets -- TODO confirm against promptr.
#
# Only tokens that are entirely a number (optionally whitespace-padded,
# e.g. " 75" or "5.") are kept. A looser "contains a digit" filter would
# also admit tokens such as "50%", which as.numeric() turns into NA. That
# NA would make the whole tweet's weighted mean NA. Tweets with no numeric
# candidate at all get NA_real_.
score_from_tokens <- function(tokens) {
  numeric_tokens <- tokens |>
    filter(str_detect(token, '^\\s*[0-9]+(\\.[0-9]*)?\\s*$')) |>
    mutate(token = as.numeric(token))
  if (nrow(numeric_tokens) == 0) {
    return(NA_real_)
  }
  weighted.mean(numeric_tokens$token, numeric_tokens$probability)
}

# vapply() guarantees exactly one numeric score per element of `out`
scotus_tweets$score <- vapply(out, score_from_tokens, numeric(1),
                              USE.NAMES = FALSE)

# Hand-coded reference score: the unweighted average of the three expert ratings
scotus_tweets$expert_score <- with(scotus_tweets, (expert1 + expert2 + expert3) / 3)

# Build the GPT-4 vs. hand-coded scatterplot and write it to figure-B1.png
# at 8 x 8 (ggsave's size units).
p <- plot_results(scotus_tweets = scotus_tweets,
                  model = 'GPT-4')

ggsave(filename = 'figure-B1.png',
       plot = p,
       height = 8,
       width = 8)
